/**
 * \file lite/load_and_run/src/models/model_mdl.cpp
 *
 * This file is part of MegEngine, a deep learning framework developed by
 * Megvii.
 *
 * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
 */

#include "model_mdl.h"

#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <algorithm>
#include <iostream>
#include <memory>
#include <string>

#include <gflags/gflags.h>

DECLARE_bool(share_param_mem);

using namespace lar;

//! Create a wrapper for the .mdl model stored at \p path.
//!
//! A fresh computing graph with graph_opt_level 0 is prepared in the load
//! config; the model itself is not read until load_model() is called.
ModelMdl::ModelMdl(const std::string& path) : model_path(path) {
    mgb_log_warn("creat mdl model use XPU as default comp node");
    auto graph = mgb::ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    m_load_config.comp_graph = std::move(graph);
    //! no testcase data known until the dump file has been inspected
    testcase_num = 0;
}

namespace {
//! Read the whole file at \p path into a heap buffer.
//!
//! \param path       path of the file to read
//! \param[out] size  set to the number of bytes in the file
//! \return buffer holding the file contents; it is released with free()
//!         when the last owner goes away
//!
//! The FILE handle is owned by a unique_ptr, so it is closed on every exit
//! path, including when one of the assertions below fails.
std::shared_ptr<void> read_model_file(const std::string& path, size_t& size) {
    std::unique_ptr<FILE, decltype(&fclose)> fin{fopen(path.c_str(), "rb"), &fclose};
    mgb_assert(
            fin != nullptr, "failed to open %s: %s", path.c_str(), strerror(errno));

    //! ftell() reports failure as -1; without this check the error would be
    //! converted to a huge size_t and handed to malloc()/fread()
    long file_size = -1;
    if (fseek(fin.get(), 0, SEEK_END) == 0) {
        file_size = ftell(fin.get());
    }
    mgb_assert(file_size >= 0, "failed to get size of %s", path.c_str());
    size = static_cast<size_t>(file_size);

    int rewind_rc = fseek(fin.get(), 0, SEEK_SET);
    mgb_assert(rewind_rc == 0, "failed to seek %s", path.c_str());

    //! the shared_ptr owns the buffer right away so it is freed if any later
    //! step fails
    std::shared_ptr<void> buf{malloc(size), free};
    mgb_assert(
            buf != nullptr || size == 0, "failed to allocate %zu bytes for %s", size,
            path.c_str());

    size_t nr = fread(buf.get(), 1, size, fin.get());
    mgb_assert(nr == size, "read model file failed");
    return buf;
}
}  // namespace

//! Load the model file into m_load_result.
//!
//! Steps:
//!  1. open the dump file, either fully read into memory (share_model_mem)
//!     or as a filesystem-backed InputFile;
//!  2. if the file starts with the "mgbtest0" magic, read the testcase count
//!     that follows it, otherwise rewind to the start;
//!  3. identify the dump format and load the graph with a GraphLoader;
//!  4. when testcases exist, collect the loaded tensor_map entries into
//!     test_input_tensors (sorted);
//!  5. append one default (empty) callback per output var.
//!
//! The load config's comp_graph is reset once loading has finished.
void ModelMdl::load_model() {
    //! read dump file
    if (share_model_mem) {
        mgb_log_warn("enable share model memory");
        size_t size = 0;
        std::shared_ptr<void> buf = read_model_file(model_path, size);
        m_model_file =
                mgb::serialization::InputFile::make_mem_proxy(std::move(buf), size);
    } else {
        m_model_file = mgb::serialization::InputFile::make_fs(model_path.c_str());
    }

    //! get dump_with_testcase model testcase number
    char magic[8];
    m_model_file->read(magic, sizeof(magic));
    if (strncmp(magic, "mgbtest0", 8)) {
        m_model_file->rewind();
    } else {
        m_model_file->read(&testcase_num, sizeof(testcase_num));
    }

    m_format =
            mgb::serialization::GraphLoader::identify_graph_dump_format(*m_model_file);
    mgb_assert(
            m_format.valid(),
            "invalid format, please make sure model is dumped by GraphDumper");

    //! load computing graph of model
    m_loader = mgb::serialization::GraphLoader::make(
            std::move(m_model_file), m_format.val());
    m_load_result = m_loader->load(m_load_config, false);
    m_load_config.comp_graph.reset();

    //! get testcase input generated by dump_with_testcase.py
    if (testcase_num) {
        for (auto&& i : m_load_result.tensor_map) {
            test_input_tensors.emplace_back(i.first, i.second.get());
        }
        std::sort(test_input_tensors.begin(), test_input_tensors.end());
    }

    //! initialize output callbacks: one default-constructed callback per
    //! output var, filled in by users before make_output_spec()
    for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) {
        m_callbacks.emplace_back();
    }
}

//! Pair every loaded output var with its callback and compile the graph.
//!
//! Callbacks are moved out of m_callbacks into m_output_spec, and the
//! compiled executable is stored in m_asyc_exec for run_model()/wait().
void ModelMdl::make_output_spec() {
    auto&& outputs = m_load_result.output_var_list;
    for (size_t idx = 0, nr_outputs = outputs.size(); idx < nr_outputs; ++idx) {
        m_output_spec.emplace_back(outputs[idx], std::move(m_callbacks[idx]));
    }

    m_asyc_exec = m_load_result.graph_compile(m_output_spec);
}

//! Replace m_loader with a new GraphLoader using the current loader's format.
//!
//! \param input_file  file for the new loader; when empty, the file is
//!                    obtained from the current loader's reset_file()
//! \return reference to the updated m_loader member
std::shared_ptr<mgb::serialization::GraphLoader>& ModelMdl::reset_loader(
        std::unique_ptr<mgb::serialization::InputFile> input_file) {
    using mgb::serialization::GraphLoader;

    if (!input_file) {
        //! no new file supplied: reuse the one held by the current loader
        m_loader = GraphLoader::make(m_loader->reset_file(), m_loader->format());
        return m_loader;
    }

    m_loader = GraphLoader::make(std::move(input_file), m_loader->format());
    return m_loader;
}

//! Call execute() on the executable compiled by make_output_spec().
//! Asserts if the graph has not been compiled yet.
void ModelMdl::run_model() {
    auto& executable = m_asyc_exec;
    mgb_assert(
            executable != nullptr,
            "empty asychronous function to execute after graph compiled");
    executable->execute();
}

void ModelMdl::wait() {
    m_asyc_exec->wait();
}
